{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Markov decision process\n",
    "\n",
    "This week's methods are all built to solve __M__arkov __D__ecision __P__rocesses. In the broadest sense, an MDP is defined by how it changes states and how rewards are computed.\n",
    "\n",
    "State transition is defined by $P(s' |s,a)$ - how likely are you to end at state $s'$ if you take action $a$ from state $s$. Now there's more than one way to define rewards, but we'll use $r(s,a,s')$ function for convenience.\n",
    "\n",
    "_This notebook is inspired by the awesome_ [CS294](https://github.com/berkeleydeeprlcourse/homework/blob/36a0b58261acde756abd55306fbe63df226bf62b/hw2/HW2.ipynb) _by Berkeley_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For starters, let's define a simple MDP from this picture:\n",
    "\n",
    "<img src=\"https://upload.wikimedia.org/wikipedia/commons/a/ad/Markov_Decision_Process.svg\" width=\"400px\" alt=\"Diagram by Waldoalvarez via Wikimedia Commons, CC BY-SA 4.0\"/>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n",
    "    !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/setup_colab.sh -O- | bash\n",
    "    !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week02_value_based/mdp.py\n",
    "    !touch .setup_complete\n",
    "\n",
    "# This code creates a virtual display to draw game images on.\n",
    "# It will have no effect if your machine has a monitor.\n",
    "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n",
    "    !bash ../xvfb start\n",
    "    os.environ['DISPLAY'] = ':1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transition_probs = {\n",
    "    's0': {\n",
    "        'a0': {'s0': 0.5, 's2': 0.5},\n",
    "        'a1': {'s2': 1}\n",
    "    },\n",
    "    's1': {\n",
    "        'a0': {'s0': 0.7, 's1': 0.1, 's2': 0.2},\n",
    "        'a1': {'s1': 0.95, 's2': 0.05}\n",
    "    },\n",
    "    's2': {\n",
    "        'a0': {'s0': 0.4, 's2': 0.6},\n",
    "        'a1': {'s0': 0.3, 's1': 0.3, 's2': 0.4}\n",
    "    }\n",
    "}\n",
    "rewards = {\n",
    "    's1': {'a0': {'s0': +5}},\n",
    "    's2': {'a1': {'s0': -1}}\n",
    "}\n",
    "\n",
    "from mdp import MDP\n",
    "mdp = MDP(transition_probs, rewards, initial_state='s0')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now use MDP just as any other gym environment:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "initial state = s0\n",
      "next_state = s2, reward = 0.0, done = False\n"
     ]
    }
   ],
   "source": [
    "print('initial state =', mdp.reset())\n",
    "next_state, reward, done, info = mdp.step('a1')\n",
    "print('next_state = %s, reward = %s, done = %s' % (next_state, reward, done))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "but it also has other methods that you'll need for Value Iteration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mdp.get_all_states = ('s0', 's1', 's2')\n",
      "mdp.get_possible_actions('s1') =  ('a0', 'a1')\n",
      "mdp.get_next_states('s1', 'a0') =  {'s0': 0.7, 's1': 0.1, 's2': 0.2}\n",
      "mdp.get_reward('s1', 'a0', 's0') =  5\n",
      "mdp.get_transition_prob('s1', 'a0', 's0') =  0.7\n"
     ]
    }
   ],
   "source": [
    "print(\"mdp.get_all_states =\", mdp.get_all_states())\n",
    "print(\"mdp.get_possible_actions('s1') = \", mdp.get_possible_actions('s1'))\n",
    "print(\"mdp.get_next_states('s1', 'a0') = \", mdp.get_next_states('s1', 'a0'))\n",
    "print(\"mdp.get_reward('s1', 'a0', 's0') = \", mdp.get_reward('s1', 'a0', 's0'))\n",
    "print(\"mdp.get_transition_prob('s1', 'a0', 's0') = \", mdp.get_transition_prob('s1', 'a0', 's0'))"
   ]
  },
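  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example, here is a short rollout under a uniformly random policy - a minimal sketch that only uses the methods demonstrated above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "s = mdp.reset()\n",
    "total_reward = 0\n",
    "for _ in range(10):\n",
    "    # pick a random action among those available in the current state\n",
    "    a = random.choice(mdp.get_possible_actions(s))\n",
    "    s, r, done, _ = mdp.step(a)\n",
    "    total_reward += r\n",
    "    if done:\n",
    "        break\n",
    "print(\"total reward after 10 random steps:\", total_reward)"
   ]
  },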
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optional: Visualizing MDPs\n",
    "\n",
    "You can also visualize any MDP with the drawing fuction donated by [neer201](https://github.com/neer201).\n",
    "\n",
    "You have to install graphviz for system and for python. \n",
    "\n",
    "1. * For ubuntu just run: `sudo apt-get install graphviz` \n",
    "   * For OSX: `brew install graphviz`\n",
    "2. `pip install graphviz`\n",
    "3. restart the notebook\n",
    "\n",
    "__Note:__ Installing graphviz on some OS (esp. Windows) may be tricky. However, you can ignore this part alltogether and use the standart vizualization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Graphviz available: True\n"
     ]
    }
   ],
   "source": [
    "from mdp import has_graphviz\n",
    "from IPython.display import display\n",
    "print(\"Graphviz available:\", has_graphviz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n",
       " -->\n",
       "<!-- Title: MDP Pages: 1 -->\n",
       "<svg width=\"720pt\" height=\"197pt\"\n",
       " viewBox=\"0.00 0.00 720.00 196.66\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(0.704871 0.704871) rotate(0) translate(4 275)\">\n",
       "<title>MDP</title>\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-275 1017.46,-275 1017.46,4 -4,4\"/>\n",
       "<!-- s0 -->\n",
       "<g id=\"node1\" class=\"node\"><title>s0</title>\n",
       "<ellipse fill=\"lightgreen\" stroke=\"lightgreen\" cx=\"40\" cy=\"-103\" rx=\"36\" ry=\"36\"/>\n",
       "<ellipse fill=\"none\" stroke=\"lightgreen\" cx=\"40\" cy=\"-103\" rx=\"40\" ry=\"40\"/>\n",
       "<text text-anchor=\"middle\" x=\"40\" y=\"-96.8\" font-family=\"Arial\" font-size=\"24.00\">s0</text>\n",
       "</g>\n",
       "<!-- s0&#45;a0 -->\n",
       "<g id=\"node2\" class=\"node\"><title>s0&#45;a0</title>\n",
       "<ellipse fill=\"lightpink\" stroke=\"lightpink\" cx=\"192.577\" cy=\"-147\" rx=\"27.6545\" ry=\"27.6545\"/>\n",
       "<text text-anchor=\"middle\" x=\"192.577\" y=\"-142\" font-family=\"Arial\" font-size=\"20.00\">a0</text>\n",
       "</g>\n",
       "<!-- s0&#45;&gt;s0&#45;a0 -->\n",
       "<g id=\"edge1\" class=\"edge\"><title>s0&#45;&gt;s0&#45;a0</title>\n",
       "<path fill=\"none\" stroke=\"red\" stroke-width=\"2\" d=\"M79.0735,-111.577C99.3931,-116.435 124.733,-122.944 147,-130 150.338,-131.058 153.79,-132.23 157.227,-133.45\"/>\n",
       "<polygon fill=\"red\" stroke=\"red\" stroke-width=\"2\" points=\"156.171,-136.791 166.764,-136.959 158.588,-130.221 156.171,-136.791\"/>\n",
       "</g>\n",
       "<!-- s0&#45;a1 -->\n",
       "<g id=\"node4\" class=\"node\"><title>s0&#45;a1</title>\n",
       "<ellipse fill=\"lightpink\" stroke=\"lightpink\" cx=\"192.577\" cy=\"-220\" rx=\"27.6545\" ry=\"27.6545\"/>\n",
       "<text text-anchor=\"middle\" x=\"192.577\" y=\"-215\" font-family=\"Arial\" font-size=\"20.00\">a1</text>\n",
       "</g>\n",
       "<!-- s0&#45;&gt;s0&#45;a1 -->\n",
       "<g id=\"edge4\" class=\"edge\"><title>s0&#45;&gt;s0&#45;a1</title>\n",
       "<path fill=\"none\" stroke=\"red\" stroke-width=\"2\" d=\"M66.9211,-132.667C76.2366,-142.537 87.136,-153.237 98,-162 117.05,-177.365 140.454,-191.87 159.091,-202.532\"/>\n",
       "<polygon fill=\"red\" stroke=\"red\" stroke-width=\"2\" points=\"157.386,-205.588 167.817,-207.443 160.82,-199.488 157.386,-205.588\"/>\n",
       "</g>\n",
       "<!-- s0&#45;a0&#45;&gt;s0 -->\n",
       "<g id=\"edge2\" class=\"edge\"><title>s0&#45;a0&#45;&gt;s0</title>\n",
       "<path fill=\"none\" stroke=\"blue\" stroke-dasharray=\"5,2\" d=\"M165.016,-142.698C146.094,-139.302 120.188,-133.974 98,-127 94.1908,-125.803 90.2855,-124.447 86.3924,-123.004\"/>\n",
       "<polygon fill=\"blue\" stroke=\"blue\" points=\"87.3907,-119.637 76.802,-119.277 84.8555,-126.161 87.3907,-119.637\"/>\n",
       "<text text-anchor=\"middle\" x=\"122.5\" y=\"-145.2\" font-family=\"Arial\" font-size=\"16.00\">p = 0.5</text>\n",
       "</g>\n",
       "<!-- s2 -->\n",
       "<g id=\"node3\" class=\"node\"><title>s2</title>\n",
       "<ellipse fill=\"lightgreen\" stroke=\"lightgreen\" cx=\"430.154\" cy=\"-159\" rx=\"36\" ry=\"36\"/>\n",
       "<ellipse fill=\"none\" stroke=\"lightgreen\" cx=\"430.154\" cy=\"-159\" rx=\"40\" ry=\"40\"/>\n",
       "<text text-anchor=\"middle\" x=\"430.154\" y=\"-152.8\" font-family=\"Arial\" font-size=\"24.00\">s2</text>\n",
       "</g>\n",
       "<!-- s0&#45;a0&#45;&gt;s2 -->\n",
       "<g id=\"edge3\" class=\"edge\"><title>s0&#45;a0&#45;&gt;s2</title>\n",
       "<path fill=\"none\" stroke=\"blue\" stroke-dasharray=\"5,2\" d=\"M220.143,-148.353C258.819,-150.323 331.305,-154.016 379.91,-156.492\"/>\n",
       "<polygon fill=\"blue\" stroke=\"blue\" points=\"380.036,-160.002 390.201,-157.016 380.392,-153.011 380.036,-160.002\"/>\n",
       "<text text-anchor=\"middle\" x=\"305.154\" y=\"-162.2\" font-family=\"Arial\" font-size=\"16.00\">p = 0.5</text>\n",
       "</g>\n",
       "<!-- s2&#45;a0 -->\n",
       "<g id=\"node8\" class=\"node\"><title>s2&#45;a0</title>\n",
       "<ellipse fill=\"lightpink\" stroke=\"lightpink\" cx=\"662.731\" cy=\"-129\" rx=\"27.6545\" ry=\"27.6545\"/>\n",
       "<text text-anchor=\"middle\" x=\"662.731\" y=\"-124\" font-family=\"Arial\" font-size=\"20.00\">a0</text>\n",
       "</g>\n",
       "<!-- s2&#45;&gt;s2&#45;a0 -->\n",
       "<g id=\"edge13\" class=\"edge\"><title>s2&#45;&gt;s2&#45;a0</title>\n",
       "<path fill=\"none\" stroke=\"red\" stroke-width=\"2\" d=\"M470.375,-158.605C508.288,-157.576 567.173,-154.264 617.154,-144 620.354,-143.343 623.642,-142.525 626.912,-141.613\"/>\n",
       "<polygon fill=\"red\" stroke=\"red\" stroke-width=\"2\" points=\"628.029,-144.931 636.567,-138.658 625.98,-138.238 628.029,-144.931\"/>\n",
       "</g>\n",
       "<!-- s2&#45;a1 -->\n",
       "<g id=\"node9\" class=\"node\"><title>s2&#45;a1</title>\n",
       "<ellipse fill=\"lightpink\" stroke=\"lightpink\" cx=\"662.731\" cy=\"-56\" rx=\"27.6545\" ry=\"27.6545\"/>\n",
       "<text text-anchor=\"middle\" x=\"662.731\" y=\"-51\" font-family=\"Arial\" font-size=\"20.00\">a1</text>\n",
       "</g>\n",
       "<!-- s2&#45;&gt;s2&#45;a1 -->\n",
       "<g id=\"edge16\" class=\"edge\"><title>s2&#45;&gt;s2&#45;a1</title>\n",
       "<path fill=\"none\" stroke=\"red\" stroke-width=\"2\" d=\"M456.579,-128.48C465.696,-119.154 476.6,-109.623 488.154,-103 531.419,-78.2 588.27,-66.0502 624.91,-60.4193\"/>\n",
       "<polygon fill=\"red\" stroke=\"red\" stroke-width=\"2\" points=\"625.742,-63.8353 635.136,-58.9365 624.738,-56.9077 625.742,-63.8353\"/>\n",
       "</g>\n",
       "<!-- s0&#45;a1&#45;&gt;s2 -->\n",
       "<g id=\"edge5\" class=\"edge\"><title>s0&#45;a1&#45;&gt;s2</title>\n",
       "<path fill=\"none\" stroke=\"blue\" stroke-dasharray=\"5,2\" d=\"M219.837,-214.929C254.866,-207.948 318.68,-194.468 372.154,-179 375.531,-178.023 379,-176.96 382.48,-175.848\"/>\n",
       "<polygon fill=\"blue\" stroke=\"blue\" points=\"383.849,-179.082 392.248,-172.624 381.655,-172.435 383.849,-179.082\"/>\n",
       "<text text-anchor=\"middle\" x=\"305.154\" y=\"-216.2\" font-family=\"Arial\" font-size=\"16.00\">p = 1</text>\n",
       "</g>\n",
       "<!-- s1 -->\n",
       "<g id=\"node5\" class=\"node\"><title>s1</title>\n",
       "<ellipse fill=\"lightgreen\" stroke=\"lightgreen\" cx=\"824.309\" cy=\"-106\" rx=\"36\" ry=\"36\"/>\n",
       "<ellipse fill=\"none\" stroke=\"lightgreen\" cx=\"824.309\" cy=\"-106\" rx=\"40\" ry=\"40\"/>\n",
       "<text text-anchor=\"middle\" x=\"824.309\" y=\"-99.8\" font-family=\"Arial\" font-size=\"24.00\">s1</text>\n",
       "</g>\n",
       "<!-- s1&#45;a0 -->\n",
       "<g id=\"node6\" class=\"node\"><title>s1&#45;a0</title>\n",
       "<ellipse fill=\"lightpink\" stroke=\"lightpink\" cx=\"985.886\" cy=\"-70\" rx=\"27.6545\" ry=\"27.6545\"/>\n",
       "<text text-anchor=\"middle\" x=\"985.886\" y=\"-65\" font-family=\"Arial\" font-size=\"20.00\">a0</text>\n",
       "</g>\n",
       "<!-- s1&#45;&gt;s1&#45;a0 -->\n",
       "<g id=\"edge6\" class=\"edge\"><title>s1&#45;&gt;s1&#45;a0</title>\n",
       "<path fill=\"none\" stroke=\"red\" stroke-width=\"2\" d=\"M863.531,-97.3771C889.464,-91.5267 923.525,-83.8428 948.959,-78.105\"/>\n",
       "<polygon fill=\"red\" stroke=\"red\" stroke-width=\"2\" points=\"949.952,-81.469 958.936,-75.8541 948.411,-74.6406 949.952,-81.469\"/>\n",
       "</g>\n",
       "<!-- s1&#45;a1 -->\n",
       "<g id=\"node7\" class=\"node\"><title>s1&#45;a1</title>\n",
       "<ellipse fill=\"lightpink\" stroke=\"lightpink\" cx=\"985.886\" cy=\"-152\" rx=\"27.6545\" ry=\"27.6545\"/>\n",
       "<text text-anchor=\"middle\" x=\"985.886\" y=\"-147\" font-family=\"Arial\" font-size=\"20.00\">a1</text>\n",
       "</g>\n",
       "<!-- s1&#45;&gt;s1&#45;a1 -->\n",
       "<g id=\"edge10\" class=\"edge\"><title>s1&#45;&gt;s1&#45;a1</title>\n",
       "<path fill=\"none\" stroke=\"red\" stroke-width=\"2\" d=\"M863.837,-113.312C886.389,-118.021 915.267,-124.791 940.309,-133 943.876,-134.169 947.554,-135.508 951.196,-136.922\"/>\n",
       "<polygon fill=\"red\" stroke=\"red\" stroke-width=\"2\" points=\"950.07,-140.243 960.651,-140.775 952.712,-133.76 950.07,-140.243\"/>\n",
       "</g>\n",
       "<!-- s1&#45;a0&#45;&gt;s0 -->\n",
       "<g id=\"edge7\" class=\"edge\"><title>s1&#45;a0&#45;&gt;s0</title>\n",
       "<path fill=\"none\" stroke=\"blue\" stroke-dasharray=\"5,2\" d=\"M963.474,-53.7646C934.012,-33.0859 878.553,-0 825.309,-0 191.577,-0 191.577,-0 191.577,-0 144.186,-0 100.753,-35.1875 72.8408,-64.4291\"/>\n",
       "<polygon fill=\"blue\" stroke=\"blue\" points=\"69.9662,-62.382 65.7381,-72.0966 75.1015,-67.139 69.9662,-62.382\"/>\n",
       "<text text-anchor=\"middle\" x=\"552.654\" y=\"-5.2\" font-family=\"Arial\" font-size=\"16.00\">p = 0.7 &#160;reward =5</text>\n",
       "</g>\n",
       "<!-- s1&#45;a0&#45;&gt;s2 -->\n",
       "<g id=\"edge9\" class=\"edge\"><title>s1&#45;a0&#45;&gt;s2</title>\n",
       "<path fill=\"none\" stroke=\"blue\" stroke-dasharray=\"5,2\" d=\"M970.179,-92.8767C966.041,-99.8269 961.738,-107.589 958.309,-115 948.039,-137.196 960.231,-152.814 940.309,-167 867.594,-218.778 592.763,-183.887 479.943,-166.845\"/>\n",
       "<polygon fill=\"blue\" stroke=\"blue\" points=\"480.276,-163.355 469.862,-165.304 479.218,-170.275 480.276,-163.355\"/>\n",
       "<text text-anchor=\"middle\" x=\"737.309\" y=\"-200.2\" font-family=\"Arial\" font-size=\"16.00\">p = 0.2</text>\n",
       "</g>\n",
       "<!-- s1&#45;a0&#45;&gt;s1 -->\n",
       "<g id=\"edge8\" class=\"edge\"><title>s1&#45;a0&#45;&gt;s1</title>\n",
       "<path fill=\"none\" stroke=\"blue\" stroke-dasharray=\"5,2\" d=\"M959.695,-61.2337C938.465,-55.2571 907.594,-49.9199 882.309,-59 874.67,-61.7432 867.325,-66.0021 860.594,-70.8598\"/>\n",
       "<polygon fill=\"blue\" stroke=\"blue\" points=\"858.135,-68.3406 852.423,-77.264 862.453,-73.8502 858.135,-68.3406\"/>\n",
       "<text text-anchor=\"middle\" x=\"911.309\" y=\"-64.2\" font-family=\"Arial\" font-size=\"16.00\">p = 0.1</text>\n",
       "</g>\n",
       "<!-- s1&#45;a1&#45;&gt;s2 -->\n",
       "<g id=\"edge12\" class=\"edge\"><title>s1&#45;a1&#45;&gt;s2</title>\n",
       "<path fill=\"none\" stroke=\"blue\" stroke-dasharray=\"5,2\" d=\"M969.759,-174.512C962.133,-183.877 952.013,-193.769 940.309,-199 756.84,-280.995 679.08,-261.7 488.154,-199 481.749,-196.897 475.43,-193.811 469.463,-190.272\"/>\n",
       "<polygon fill=\"blue\" stroke=\"blue\" points=\"471.306,-187.296 461.003,-184.823 467.515,-193.181 471.306,-187.296\"/>\n",
       "<text text-anchor=\"middle\" x=\"737.309\" y=\"-258.2\" font-family=\"Arial\" font-size=\"16.00\">p = 0.05</text>\n",
       "</g>\n",
       "<!-- s1&#45;a1&#45;&gt;s1 -->\n",
       "<g id=\"edge11\" class=\"edge\"><title>s1&#45;a1&#45;&gt;s1</title>\n",
       "<path fill=\"none\" stroke=\"blue\" stroke-dasharray=\"5,2\" d=\"M958.394,-147.727C937.471,-144.016 907.674,-137.983 882.309,-130 878.5,-128.801 874.595,-127.444 870.702,-126.001\"/>\n",
       "<polygon fill=\"blue\" stroke=\"blue\" points=\"871.701,-122.634 861.112,-122.273 869.165,-129.158 871.701,-122.634\"/>\n",
       "<text text-anchor=\"middle\" x=\"911.309\" y=\"-150.2\" font-family=\"Arial\" font-size=\"16.00\">p = 0.95</text>\n",
       "</g>\n",
       "<!-- s2&#45;a0&#45;&gt;s0 -->\n",
       "<g id=\"edge14\" class=\"edge\"><title>s2&#45;a0&#45;&gt;s0</title>\n",
       "<path fill=\"none\" stroke=\"blue\" stroke-dasharray=\"5,2\" d=\"M639.021,-114.018C632.261,-110.335 624.651,-106.904 617.154,-105 427.072,-56.7313 190.201,-81.2241 89.7024,-95.3394\"/>\n",
       "<polygon fill=\"blue\" stroke=\"blue\" points=\"88.9547,-91.9108 79.5531,-96.7954 89.9488,-98.8398 88.9547,-91.9108\"/>\n",
       "<text text-anchor=\"middle\" x=\"305.154\" y=\"-86.2\" font-family=\"Arial\" font-size=\"16.00\">p = 0.4</text>\n",
       "</g>\n",
       "<!-- s2&#45;a0&#45;&gt;s1 -->\n",
       "<g id=\"edge15\" class=\"edge\"><title>s2&#45;a0&#45;&gt;s1</title>\n",
       "<path fill=\"none\" stroke=\"blue\" stroke-dasharray=\"5,2\" d=\"M690.082,-125.202C713.025,-121.895 746.878,-117.016 774.717,-113.004\"/>\n",
       "<polygon fill=\"blue\" stroke=\"blue\" points=\"775.272,-116.46 784.67,-111.569 774.273,-109.531 775.272,-116.46\"/>\n",
       "<text text-anchor=\"middle\" x=\"737.309\" y=\"-128.2\" font-family=\"Arial\" font-size=\"16.00\">p = 0.6</text>\n",
       "</g>\n",
       "<!-- s2&#45;a1&#45;&gt;s0 -->\n",
       "<g id=\"edge17\" class=\"edge\"><title>s2&#45;a1&#45;&gt;s0</title>\n",
       "<path fill=\"none\" stroke=\"blue\" stroke-dasharray=\"5,2\" d=\"M634.867,-53.9388C567.736,-49.2213 387.339,-39.2124 238.154,-55 185.836,-60.5367 127.463,-75.9491 87.9054,-87.8304\"/>\n",
       "<polygon fill=\"blue\" stroke=\"blue\" points=\"86.8144,-84.5039 78.2688,-90.7669 88.8548,-91.1999 86.8144,-84.5039\"/>\n",
       "<text text-anchor=\"middle\" x=\"305.154\" y=\"-60.2\" font-family=\"Arial\" font-size=\"16.00\">p = 0.3 &#160;reward =&#45;1</text>\n",
       "</g>\n",
       "<!-- s2&#45;a1&#45;&gt;s2 -->\n",
       "<g id=\"edge19\" class=\"edge\"><title>s2&#45;a1&#45;&gt;s2</title>\n",
       "<path fill=\"none\" stroke=\"blue\" stroke-dasharray=\"5,2\" d=\"M644.001,-76.5595C636.457,-84.0842 627.104,-91.9769 617.154,-97 565.021,-123.32 542.31,-100.144 488.154,-122 482.617,-124.235 477.042,-127.032 471.66,-130.082\"/>\n",
       "<polygon fill=\"blue\" stroke=\"blue\" points=\"469.588,-127.243 462.825,-135.398 473.198,-133.241 469.588,-127.243\"/>\n",
       "<text text-anchor=\"middle\" x=\"552.654\" y=\"-127.2\" font-family=\"Arial\" font-size=\"16.00\">p = 0.4</text>\n",
       "</g>\n",
       "<!-- s2&#45;a1&#45;&gt;s1 -->\n",
       "<g id=\"edge18\" class=\"edge\"><title>s2&#45;a1&#45;&gt;s1</title>\n",
       "<path fill=\"none\" stroke=\"blue\" stroke-dasharray=\"5,2\" d=\"M689.37,-64.0334C712.698,-71.3427 747.719,-82.3157 776.1,-91.2083\"/>\n",
       "<polygon fill=\"blue\" stroke=\"blue\" points=\"775.243,-94.6074 785.832,-94.2575 777.336,-87.9276 775.243,-94.6074\"/>\n",
       "<text text-anchor=\"middle\" x=\"737.309\" y=\"-92.2\" font-family=\"Arial\" font-size=\"16.00\">p = 0.3</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.dot.Digraph at 0x7f729b9db7b8>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "if has_graphviz:\n",
    "    from mdp import plot_graph, plot_graph_with_state_values, plot_graph_optimal_strategy_and_state_values\n",
    "    display(plot_graph(mdp))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Value Iteration\n",
    "\n",
    "Now let's build something to solve this MDP. The simplest algorithm so far is __V__alue __I__teration\n",
    "\n",
    "Here's the pseudo-code for VI:\n",
    "\n",
    "---\n",
    "\n",
    "`1.` Initialize $V^{(0)}(s)=0$, for all $s$\n",
    "\n",
    "`2.` For $i=0, 1, 2, \\dots$\n",
    " \n",
    "`3.` $ \\quad V_{(i+1)}(s) = \\max_a \\sum_{s'} P(s' | s,a) \\cdot [ r(s,a,s') + \\gamma V_{i}(s')]$, for all $s$\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, let's write a function to compute the state-action value function $Q^{\\pi}$, defined as follows\n",
    "\n",
    "$$Q_i(s, a) = \\sum_{s'} P(s' | s,a) \\cdot [ r(s,a,s') + \\gamma V_{i}(s')]$$\n"
   ]
  },
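  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example (using the test values from the cell below: $V(s0)=0$, $V(s1)=1$, $V(s2)=2$ and $\\gamma=0.9$):\n",
    "\n",
    "$$Q(s2, a1) = 0.3 \\cdot (-1 + 0.9 \\cdot 0) + 0.3 \\cdot (0 + 0.9 \\cdot 1) + 0.4 \\cdot (0 + 0.9 \\cdot 2) = -0.3 + 0.27 + 0.72 = 0.69$$"
   ]
  },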
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_action_value(mdp, state_values, state, action, gamma):\n",
    "    \"\"\" Computes Q(s,a) as in formula above \"\"\"\n",
    "\n",
    "    <YOUR CODE>\n",
    "\n",
    "    return <YOUR CODE>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "test_Vs = {s: i for i, s in enumerate(sorted(mdp.get_all_states()))}\n",
    "assert np.isclose(get_action_value(mdp, test_Vs, 's2', 'a1', 0.9), 0.69)\n",
    "assert np.isclose(get_action_value(mdp, test_Vs, 's1', 'a0', 0.9), 3.95)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using $Q(s,a)$ we can now define the \"next\" V(s) for value iteration.\n",
    " $$V_{(i+1)}(s) = \\max_a \\sum_{s'} P(s' | s,a) \\cdot [ r(s,a,s') + \\gamma V_{i}(s')] = \\max_a Q_i(s,a)$$"
   ]
  },
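  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Continuing the worked example with the same test values: $Q(s2, a0) = 0.4 \\cdot (0 + 0.9 \\cdot 0) + 0.6 \\cdot (0 + 0.9 \\cdot 2) = 1.08$ and $Q(s2, a1) = 0.69$, so the new value is $V(s2) = \\max(1.08, 0.69) = 1.08$ - exactly what the tests below expect."
   ]
  },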
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_new_state_value(mdp, state_values, state, gamma):\n",
    "    \"\"\" Computes next V(s) as in formula above. Please do not change state_values in process. \"\"\"\n",
    "    if mdp.is_terminal(state):\n",
    "        return 0\n",
    "\n",
    "    <YOUR CODE>\n",
    "    \n",
    "    return <YOUR CODE>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_Vs_copy = dict(test_Vs)\n",
    "assert np.isclose(get_new_state_value(mdp, test_Vs, 's0', 0.9), 1.8)\n",
    "assert np.isclose(get_new_state_value(mdp, test_Vs, 's2', 0.9), 1.08)\n",
    "assert np.isclose(get_new_state_value(mdp, {'s0': -1e10, 's1': 0, 's2': -2e10}, 's0', 0.9), -13500000000.0), \\\n",
    "    \"Please ensure that you handle negative Q-values of arbitrary magnitude correctly\"\n",
    "assert test_Vs == test_Vs_copy, \"Please do not change state_values in get_new_state_value\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, let's combine everything we wrote into a working value iteration algo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# parameters\n",
    "gamma = 0.9            # discount for MDP\n",
    "num_iter = 100         # maximum iterations, excluding initialization\n",
    "# stop VI if new values are this close to old values (or closer)\n",
    "min_difference = 0.001\n",
    "\n",
    "# initialize V(s)\n",
    "state_values = {s: 0 for s in mdp.get_all_states()}\n",
    "\n",
    "if has_graphviz:\n",
    "    display(plot_graph_with_state_values(mdp, state_values))\n",
    "\n",
    "for i in range(num_iter):\n",
    "\n",
    "    # Compute new state values using the functions you defined above.\n",
    "    # It must be a dict {state : float V_new(state)}\n",
    "    new_state_values = <YOUR CODE>\n",
    "\n",
    "    assert isinstance(new_state_values, dict)\n",
    "\n",
    "    # Compute difference\n",
    "    diff = max(abs(new_state_values[s] - state_values[s])\n",
    "               for s in mdp.get_all_states())\n",
    "    print(\"iter %4i   |   diff: %6.5f   |   \" % (i, diff), end=\"\")\n",
    "    print('   '.join(\"V(%s) = %.3f\" % (s, v) for s, v in state_values.items()))\n",
    "    state_values = new_state_values\n",
    "\n",
    "    if diff < min_difference:\n",
    "        print(\"Terminated\")\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if has_graphviz:\n",
    "    display(plot_graph_with_state_values(mdp, state_values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Final state values:\", state_values)\n",
    "\n",
    "assert abs(state_values['s0'] - 3.781) < 0.01\n",
    "assert abs(state_values['s1'] - 7.294) < 0.01\n",
    "assert abs(state_values['s2'] - 4.202) < 0.01"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's use those $V^{*}(s)$ to find optimal actions in each state\n",
    "\n",
    " $$\\pi^*(s) = argmax_a \\sum_{s'} P(s' | s,a) \\cdot [ r(s,a,s') + \\gamma V_{i}(s')] = argmax_a Q_i(s,a)$$\n",
    " \n",
    "The only difference vs V(s) is that here we take not max but argmax: find action such with maximum Q(s,a)."
   ]
  },
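  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For instance, with the converged values above ($V^*(s0) \\approx 3.781$, $V^*(s2) \\approx 4.202$): $Q(s0, a0) \\approx 0.5 \\cdot 0.9 \\cdot 3.781 + 0.5 \\cdot 0.9 \\cdot 4.202 \\approx 3.59$, while $Q(s0, a1) \\approx 0.9 \\cdot 4.202 \\approx 3.78$. Hence $\\pi^*(s0) = a1$, as the test below expects."
   ]
  },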
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_optimal_action(mdp, state_values, state, gamma=0.9):\n",
    "    \"\"\" Finds optimal action using formula above. \"\"\"\n",
    "    if mdp.is_terminal(state):\n",
    "        return None\n",
    "\n",
    "    <YOUR CODE>\n",
    "\n",
    "    return <YOUR CODE>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert get_optimal_action(mdp, state_values, 's0', gamma) == 'a1'\n",
    "assert get_optimal_action(mdp, state_values, 's1', gamma) == 'a0'\n",
    "assert get_optimal_action(mdp, state_values, 's2', gamma) == 'a1'\n",
    "\n",
    "assert get_optimal_action(mdp, {'s0': -1e10, 's1': 0, 's2': -2e10}, 's0', 0.9) == 'a0', \\\n",
    "    \"Please ensure that you handle negative Q-values of arbitrary magnitude correctly\"\n",
    "assert get_optimal_action(mdp, {'s0': -2e10, 's1': 0, 's2': -1e10}, 's0', 0.9) == 'a1', \\\n",
    "    \"Please ensure that you handle negative Q-values of arbitrary magnitude correctly\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if has_graphviz:\n",
    "    display(plot_graph_optimal_strategy_and_state_values(mdp, state_values, get_action_value))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Measure agent's average reward\n",
    "\n",
    "s = mdp.reset()\n",
    "rewards = []\n",
    "for _ in range(10000):\n",
    "    s, r, done, _ = mdp.step(get_optimal_action(mdp, state_values, s, gamma))\n",
    "    rewards.append(r)\n",
    "\n",
    "print(\"average reward: \", np.mean(rewards))\n",
    "\n",
    "assert(0.40 < np.mean(rewards) < 0.55)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Frozen lake"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mdp import FrozenLakeEnv\n",
    "mdp = FrozenLakeEnv(slip_chance=0)\n",
    "\n",
    "mdp.render()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def value_iteration(mdp, state_values=None, gamma=0.9, num_iter=1000, min_difference=1e-5):\n",
    "    \"\"\" performs num_iter value iteration steps starting from state_values. Same as before but in a function \"\"\"\n",
    "    state_values = state_values or {s: 0 for s in mdp.get_all_states()}\n",
    "    for i in range(num_iter):\n",
    "\n",
    "        # Compute new state values using the functions you defined above. It must be a dict {state : new_V(state)}\n",
    "        new_state_values = <YOUR CODE>\n",
    "\n",
    "        assert isinstance(new_state_values, dict)\n",
    "\n",
    "        # Compute difference\n",
    "        diff = max(abs(new_state_values[s] - state_values[s])\n",
    "                   for s in mdp.get_all_states())\n",
    "\n",
    "        print(\"iter %4i   |   diff: %6.5f   |   V(start): %.3f \" %\n",
    "              (i, diff, new_state_values[mdp._initial_state]))\n",
    "\n",
    "        state_values = new_state_values\n",
    "        if diff < min_difference:\n",
    "            break\n",
    "\n",
    "    return state_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state_values = value_iteration(mdp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s = mdp.reset()\n",
    "mdp.render()\n",
    "for t in range(100):\n",
    "    a = get_optimal_action(mdp, state_values, s, gamma)\n",
    "    print(a, end='\\n\\n')\n",
    "    s, r, done, _ = mdp.step(a)\n",
    "    mdp.render()\n",
    "    if done:\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's visualize!\n",
    "\n",
    "It's usually interesting to see what your algorithm actually learned under the hood. To do so, we'll plot state value functions and optimal actions at each VI step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "def draw_policy(mdp, state_values):\n",
    "    plt.figure(figsize=(3, 3))\n",
    "    h, w = mdp.desc.shape\n",
    "    states = sorted(mdp.get_all_states())\n",
    "    V = np.array([state_values[s] for s in states])\n",
    "    Pi = {s: get_optimal_action(mdp, state_values, s, gamma) for s in states}\n",
    "    plt.imshow(V.reshape(w, h), cmap='gray', interpolation='none', clim=(0, 1))\n",
    "    ax = plt.gca()\n",
    "    ax.set_xticks(np.arange(h)-.5)\n",
    "    ax.set_yticks(np.arange(w)-.5)\n",
    "    ax.set_xticklabels([])\n",
    "    ax.set_yticklabels([])\n",
    "    Y, X = np.mgrid[0:4, 0:4]\n",
    "    a2uv = {'left': (-1, 0), 'down': (0, -1), 'right': (1, 0), 'up': (0, 1)}\n",
    "    for y in range(h):\n",
    "        for x in range(w):\n",
    "            plt.text(x, y, str(mdp.desc[y, x].item()),\n",
    "                     color='g', size=12,  verticalalignment='center',\n",
    "                     horizontalalignment='center', fontweight='bold')\n",
    "            a = Pi[y, x]\n",
    "            if a is None:\n",
    "                continue\n",
    "            u, v = a2uv[a]\n",
    "            plt.arrow(x, y, u*.3, -v*.3, color='m',\n",
    "                      head_width=0.1, head_length=0.1)\n",
    "    plt.grid(color='b', lw=2, ls='-')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state_values = {s: 0 for s in mdp.get_all_states()}\n",
    "\n",
    "for i in range(10):\n",
    "    print(\"after iteration %i\" % i)\n",
    "    state_values = value_iteration(mdp, state_values, num_iter=1)\n",
    "    draw_policy(mdp, state_values)\n",
    "# please ignore iter 0 at each step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import clear_output\n",
    "from time import sleep\n",
    "mdp = FrozenLakeEnv(map_name='8x8', slip_chance=0.1)\n",
    "state_values = {s: 0 for s in mdp.get_all_states()}\n",
    "\n",
    "for i in range(30):\n",
    "    clear_output(True)\n",
    "    print(\"after iteration %i\" % i)\n",
    "    state_values = value_iteration(mdp, state_values, num_iter=1)\n",
    "    draw_policy(mdp, state_values)\n",
    "    sleep(0.5)\n",
    "# please ignore iter 0 at each step"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Massive tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mdp = FrozenLakeEnv(slip_chance=0)\n",
    "state_values = value_iteration(mdp)\n",
    "\n",
    "total_rewards = []\n",
    "for game_i in range(1000):\n",
    "    s = mdp.reset()\n",
    "    rewards = []\n",
    "    for t in range(100):\n",
    "        s, r, done, _ = mdp.step(\n",
    "            get_optimal_action(mdp, state_values, s, gamma))\n",
    "        rewards.append(r)\n",
    "        if done:\n",
    "            break\n",
    "    total_rewards.append(np.sum(rewards))\n",
    "\n",
    "print(\"average reward: \", np.mean(total_rewards))\n",
    "assert(1.0 <= np.mean(total_rewards) <= 1.0)\n",
    "print(\"Well done!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Measure agent's average reward\n",
    "mdp = FrozenLakeEnv(slip_chance=0.1)\n",
    "state_values = value_iteration(mdp)\n",
    "\n",
    "total_rewards = []\n",
    "for game_i in range(1000):\n",
    "    s = mdp.reset()\n",
    "    rewards = []\n",
    "    for t in range(100):\n",
    "        s, r, done, _ = mdp.step(\n",
    "            get_optimal_action(mdp, state_values, s, gamma))\n",
    "        rewards.append(r)\n",
    "        if done:\n",
    "            break\n",
    "    total_rewards.append(np.sum(rewards))\n",
    "\n",
    "print(\"average reward: \", np.mean(total_rewards))\n",
    "assert(0.8 <= np.mean(total_rewards) <= 0.95)\n",
    "print(\"Well done!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Measure agent's average reward\n",
    "mdp = FrozenLakeEnv(slip_chance=0.25)\n",
    "state_values = value_iteration(mdp)\n",
    "\n",
    "total_rewards = []\n",
    "for game_i in range(1000):\n",
    "    s = mdp.reset()\n",
    "    rewards = []\n",
    "    for t in range(100):\n",
    "        s, r, done, _ = mdp.step(\n",
    "            get_optimal_action(mdp, state_values, s, gamma))\n",
    "        rewards.append(r)\n",
    "        if done:\n",
    "            break\n",
    "    total_rewards.append(np.sum(rewards))\n",
    "\n",
    "print(\"average reward: \", np.mean(total_rewards))\n",
    "assert(0.6 <= np.mean(total_rewards) <= 0.7)\n",
    "print(\"Well done!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Measure agent's average reward\n",
    "mdp = FrozenLakeEnv(slip_chance=0.2, map_name='8x8')\n",
    "state_values = value_iteration(mdp)\n",
    "\n",
    "total_rewards = []\n",
    "for game_i in range(1000):\n",
    "    s = mdp.reset()\n",
    "    rewards = []\n",
    "    for t in range(100):\n",
    "        s, r, done, _ = mdp.step(\n",
    "            get_optimal_action(mdp, state_values, s, gamma))\n",
    "        rewards.append(r)\n",
    "        if done:\n",
    "            break\n",
    "    total_rewards.append(np.sum(rewards))\n",
    "\n",
    "print(\"average reward: \", np.mean(total_rewards))\n",
    "assert(0.6 <= np.mean(total_rewards) <= 0.8)\n",
    "print(\"Well done!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# HW Part 1: Value iteration convergence\n",
    "\n",
    "### Find an MDP for which value iteration takes long to converge  (1 pts)\n",
    "\n",
    "When we ran value iteration on the small frozen lake problem, the last iteration where an action changed was iteration 6--i.e., value iteration computed the optimal policy at iteration 6. Are there any guarantees regarding how many iterations it'll take value iteration to compute the optimal policy? There are no such guarantees without additional assumptions--we can construct the MDP in such a way that the greedy policy will change after arbitrarily many iterations.\n",
    "\n",
    "Your task: define an MDP with at most 3 states and 2 actions, such that when you run value iteration, the optimal action changes at iteration >= 50. Use discount=0.95. (However, note that the discount doesn't matter here--you can construct an appropriate MDP with any discount.)\n",
    "\n",
    "Note: value function must change at least once after iteration >=50, not necessarily change on every iteration till >=50."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transition_probs = {\n",
    "    <YOUR CODE>\n",
    "}\n",
    "rewards = {\n",
    "    <YOUR CODE>\n",
    "}\n",
    "\n",
    "from mdp import MDP\n",
    "from numpy import random\n",
    "mdp = MDP(transition_probs, rewards, initial_state=random.choice(tuple(transition_probs.keys())))\n",
    "# Feel free to change the initial_state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state_values = {s: 0 for s in mdp.get_all_states()}\n",
    "policy = np.array([get_optimal_action(mdp, state_values, state, gamma)\n",
    "                   for state in sorted(mdp.get_all_states())])\n",
    "\n",
    "for i in range(100):\n",
    "    print(\"after iteration %i\" % i)\n",
    "    state_values = value_iteration(mdp, state_values, num_iter=1)\n",
    "\n",
    "    new_policy = np.array([get_optimal_action(mdp, state_values, state, gamma)\n",
    "                           for state in sorted(mdp.get_all_states())])\n",
    "\n",
    "    n_changes = (policy != new_policy).sum()\n",
    "    print(\"N actions changed = %i \\n\" % n_changes)\n",
    "    policy = new_policy\n",
    "\n",
    "# please ignore iter 0 at each step"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Value iteration convervence proof (1 pts)\n",
    "**Note:** Assume that $\\mathcal{S}, \\mathcal{A}$ are finite.\n",
    "\n",
    "Update of value function in value iteration can be rewritten in a form of Bellman operator:\n",
    "\n",
    "$$(TV)(s) = \\max_{a \\in \\mathcal{A}}\\mathbb{E}\\left[ r_{t+1} + \\gamma V(s_{t+1}) | s_t = s, a_t = a\\right]$$\n",
    "\n",
    "Value iteration algorithm with Bellman operator:\n",
    "\n",
    "---\n",
    "&nbsp;&nbsp; Initialize $V_0$\n",
    "\n",
    "&nbsp;&nbsp; **for** $k = 0,1,2,...$ **do**\n",
    "\n",
    "&nbsp;&nbsp;&nbsp;&nbsp; $V_{k+1} \\leftarrow TV_k$\n",
    "\n",
    "&nbsp;&nbsp;**end for**\n",
    "\n",
    "---\n",
    "\n",
    "In [lecture](https://docs.google.com/presentation/d/1lz2oIUTvd2MHWKEQSH8hquS66oe4MZ_eRvVViZs2uuE/edit#slide=id.g4fd6bae29e_2_4) we established contraction property of bellman operator:\n",
    "\n",
    "$$\n",
    "||TV - TU||_{\\infty} \\le \\gamma ||V - U||_{\\infty}\n",
    "$$\n",
    "\n",
    "For all $V, U$\n",
    "\n",
    "Using contraction property of Bellman operator, Banach fixed-point theorem and Bellman equations prove that value function converges to $V^*$ in value iterateion$\n",
    "\n",
    "*<-- Your proof here -->*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Asynchronious value iteration (2 pts)\n",
    "\n",
    "Consider the following algorithm:\n",
    "\n",
    "---\n",
    "\n",
    "Initialize $V_0$\n",
    "\n",
    "**for** $k = 0,1,2,...$ **do**\n",
    "\n",
    "&nbsp;&nbsp;&nbsp;&nbsp; Select some state $s_k \\in \\mathcal{S}$    \n",
    "\n",
    "&nbsp;&nbsp;&nbsp;&nbsp; $V(s_k) := (TV)(s_k)$\n",
    "\n",
    "**end for**\n",
    "\n",
    "---\n",
    "\n",
    "\n",
    "Note that unlike common value iteration, here we update only a single state at a time.\n",
    "\n",
    "**Homework.** Prove the following proposition:\n",
    "\n",
    "If for all $s \\in \\mathcal{S}$, $s$ appears in the sequence $(s_0, s_1, ...)$ infinitely often, then $V$ converges to $V*$\n",
    "\n",
    "*<-- Your proof here -->*\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# HW Part 2: Policy iteration\n",
    "\n",
    "## Policy iteration implementateion (3 pts)\n",
    "\n",
    "Let's implement exact policy iteration (PI), which has the following pseudocode:\n",
    "\n",
    "---\n",
    "Initialize $\\pi_0$   `// random or fixed action`\n",
    "\n",
    "For $n=0, 1, 2, \\dots$\n",
    "- Compute the state-value function $V^{\\pi_{n}}$\n",
    "- Using $V^{\\pi_{n}}$, compute the state-action-value function $Q^{\\pi_{n}}$\n",
    "- Compute new policy $\\pi_{n+1}(s) = \\operatorname*{argmax}_a Q^{\\pi_{n}}(s,a)$\n",
    "---\n",
    "\n",
    "Unlike VI, policy iteration has to maintain a policy - chosen actions from all states - and estimate $V^{\\pi_{n}}$ based on this policy. It only changes policy once values converged.\n",
    "\n",
    "\n",
    "Below are a few helpers that you may or may not use in your implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transition_probs = {\n",
    "    's0': {\n",
    "        'a0': {'s0': 0.5, 's2': 0.5},\n",
    "        'a1': {'s2': 1}\n",
    "    },\n",
    "    's1': {\n",
    "        'a0': {'s0': 0.7, 's1': 0.1, 's2': 0.2},\n",
    "        'a1': {'s1': 0.95, 's2': 0.05}\n",
    "    },\n",
    "    's2': {\n",
    "        'a0': {'s0': 0.4, 's1': 0.6},\n",
    "        'a1': {'s0': 0.3, 's1': 0.3, 's2': 0.4}\n",
    "    }\n",
    "}\n",
    "rewards = {\n",
    "    's1': {'a0': {'s0': +5}},\n",
    "    's2': {'a1': {'s0': -1}}\n",
    "}\n",
    "\n",
    "from mdp import MDP\n",
    "mdp = MDP(transition_probs, rewards, initial_state='s0')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's write a function called `compute_vpi` that computes the state-value function $V^{\\pi}$ for an arbitrary policy $\\pi$.\n",
    "\n",
    "Unlike VI, this time you must find the exact solution, not just a single iteration.\n",
    "\n",
    "Recall that $V^{\\pi}$ satisfies the following linear equation:\n",
    "$$V^{\\pi}(s) = \\sum_{s'} P(s,\\pi(s),s')[ R(s,\\pi(s),s') + \\gamma V^{\\pi}(s')]$$\n",
    "\n",
    "You'll have to solve a linear system in your code. (Find an exact solution, e.g., with `np.linalg.solve`.)"
   ]
  },
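  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reminder of the tool, here is a tiny self-contained example of solving a system of the same shape, $v = r + \\gamma P v$, with `np.linalg.solve` (toy values for $P$, $r$ and $\\gamma$ - a generic sketch, not the solution to the exercise):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Toy system:  v0 = 1 + 0.5 * v1,  v1 = 1 + 0.5 * v0.\n",
    "# Rewrite as (I - gamma * P) v = r and solve it exactly.\n",
    "gamma_toy = 0.5\n",
    "P = np.array([[0.0, 1.0],\n",
    "              [1.0, 0.0]])\n",
    "r = np.array([1.0, 1.0])\n",
    "v = np.linalg.solve(np.eye(2) - gamma_toy * P, r)\n",
    "print(v)  # [2. 2.]"
   ]
  },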
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_vpi(mdp, policy, gamma):\n",
    "    \"\"\"\n",
    "    Computes V^pi(s) FOR ALL STATES under given policy.\n",
    "    :param policy: a dict of currently chosen actions {s : a}\n",
    "    :returns: a dict {state : V^pi(state) for all states}\n",
    "    \"\"\"\n",
    "    <YOUR CODE>\n",
    "    return <YOUR CODE>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_policy = {s: np.random.choice(\n",
    "    mdp.get_possible_actions(s)) for s in mdp.get_all_states()}\n",
    "new_vpi = compute_vpi(mdp, test_policy, gamma)\n",
    "\n",
    "print(new_vpi)\n",
    "\n",
    "assert type(\n",
    "    new_vpi) is dict, \"compute_vpi must return a dict {state : V^pi(state) for all states}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once we've got new state values, it's time to update our policy."
   ]
  },
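  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A generic Python idiom that may help here: taking an argmax over a dictionary of scores with `max(..., key=...)` (a standalone sketch with made-up scores):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generic argmax over a dict of scores (toy values, unrelated to the MDP above).\n",
    "scores = {'a0': 1.5, 'a1': 3.0}\n",
    "best = max(scores, key=scores.get)\n",
    "print(best)  # 'a1'"
   ]
  },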
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_new_policy(mdp, vpi, gamma):\n",
    "    \"\"\"\n",
    "    Computes new policy as argmax of state values\n",
    "    :param vpi: a dict {state : V^pi(state) for all states}\n",
    "    :returns: a dict {state : optimal action for all states}\n",
    "    \"\"\"\n",
    "    <YOUR CODE>\n",
    "    return <YOUR CODE>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_policy = compute_new_policy(mdp, new_vpi, gamma)\n",
    "\n",
    "print(new_policy)\n",
    "\n",
    "assert type(\n",
    "    new_policy) is dict, \"compute_new_policy must return a dict {state : optimal action for all states}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__Main loop__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def policy_iteration(mdp, policy=None, gamma=0.9, num_iter=1000, min_difference=1e-5):\n",
    "    \"\"\" \n",
    "    Run the policy iteration loop for num_iter iterations or till difference between V(s) is below min_difference.\n",
    "    If policy is not given, initialize it at random.\n",
    "    \"\"\"\n",
    "    <YOUR CODE: a whole lot of it>\n",
    "\n",
    "    return state_values, policy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__Your PI Results__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "<YOUR CODE: compare PI and VI on the MDP from bonus 1, then on small & large FrozenLake>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Policy iteration convergence (3 pts)\n",
    "\n",
    "**Note:** Assume that $\\mathcal{S}, \\mathcal{A}$ are finite.\n",
    "\n",
    "We can define another Bellman operator:\n",
    "\n",
    "$$(T_{\\pi}V)(s) = \\mathbb{E}_{r, s'|s, a = \\pi(s)}\\left[r + \\gamma V(s')\\right]$$\n",
    "\n",
    "And rewrite policy iteration algorithm in operator form:\n",
    "\n",
    "\n",
    "---\n",
    "\n",
    "Initialize $\\pi_0$\n",
    "\n",
    "**for** $k = 0,1,2,...$ **do**\n",
    "\n",
    "&nbsp;&nbsp;&nbsp;&nbsp; Solve $V_k = T_{\\pi_k}V_k$   \n",
    "\n",
    "&nbsp;&nbsp;&nbsp;&nbsp; Select $\\pi_{k+1}$ s.t. $T_{\\pi_{k+1}}V_k = TV_k$ \n",
    "\n",
    "**end for**\n",
    "\n",
    "---\n",
    "\n",
    "To prove convergence of the algorithm we need to prove two properties: contraction an monotonicity.\n",
    "\n",
    "#### Monotonicity (0.5 pts)\n",
    "\n",
    "For all $V, U$ if $V(s) \\le U(s)$   $\\forall s \\in \\mathcal{S}$ then $(T_\\pi V)(s) \\le (T_\\pi U)(s)$   $\\forall s \\in  \\mathcal{S}$\n",
    "\n",
    "*<-- Your proof here -->*\n",
    "\n",
    "#### Contraction (1 pts)\n",
    "\n",
    "$$\n",
    "||T_\\pi V - T_\\pi U||_{\\infty} \\le \\gamma ||V - U||_{\\infty}\n",
    "$$\n",
    "\n",
    "For all $V, U$\n",
    "\n",
    "*<-- Your proof here -->*\n",
    "\n",
    "#### Convergence (1.5 pts)\n",
    "\n",
    "Prove that there exists iteration $k_0$ such that $\\pi_k = \\pi^*$ for all $k \\ge k_0$\n",
    "\n",
    "*<-- Your proof here -->*"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
